# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer


def get_aiter_mla_metadata(
    max_batch_size: int, block_size: int, max_block_per_batch: int, device: torch.device
) -> tuple[torch.Tensor, ...]:
    """Allocate the int32 index/metadata buffers for AITER MLA decode.

    Args:
        max_batch_size: Number of sequences the buffers must accommodate.
        block_size: Value every entry of ``paged_kv_last_page_lens`` is
            initialised to.
        max_block_per_batch: Maximum number of KV blocks per sequence; together
            with ``max_batch_size`` it sizes ``paged_kv_indices``.
        device: Device on which ALL returned tensors are allocated.

    Returns:
        A 4-tuple ``(paged_kv_indices, paged_kv_indptr,
        paged_kv_last_page_lens, qo_indptr)`` of ``torch.int32`` tensors:

        * ``paged_kv_indices``: zeros, shape
          ``(max_batch_size * max_block_per_batch,)``
        * ``paged_kv_indptr``: zeros, shape ``(max_batch_size + 1,)``
        * ``paged_kv_last_page_lens``: filled with ``block_size``, shape
          ``(max_batch_size,)``
        * ``qo_indptr``: zeros, shape ``(max_batch_size + 1,)``
    """
    paged_kv_indices = torch.zeros(
        max_batch_size * max_block_per_batch, dtype=torch.int32, device=device
    )
    paged_kv_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int32, device=device)
    # Allocate on ``device`` like the sibling buffers. Without it this tensor
    # silently landed on the default (CPU) device. These buffers mirror the
    # kv_* / qo_indptr arguments of aiter_mla_decode_fwd, which presumably
    # need them all on the same device as q (confirm with callers).
    paged_kv_last_page_lens = torch.full(
        (max_batch_size,), block_size, dtype=torch.int32, device=device
    )
    qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int32, device=device)
    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr


def aiter_mla_decode_fwd(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    sm_scale: float,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: torch.Tensor | None = None,
    kv_indices: torch.Tensor | None = None,
    kv_last_page_lens: torch.Tensor | None = None,
    logit_cap: float = 0.0,
):
    """Dispatch MLA decode through ``torch.ops.vllm.rocm_aiter_mla_decode_fwd``.

    The result is written into ``o`` in place; nothing is returned. The custom
    op is only registered when running on ROCm, so calling this elsewhere
    fails at the op lookup.

    Args:
        q: Query tensor; its last dimension is used as the per-slot width when
            reshaping ``kv_buffer``.
        kv_buffer: KV cache, reshaped to ``(-1, 1, 1, q.shape[-1])`` before the
            call (must be view-compatible with that shape).
        o: Output tensor, mutated by the op.
        sm_scale: Softmax scale, forwarded by keyword.
        qo_indptr, max_seqlen_qo, kv_indptr, kv_indices, kv_last_page_lens:
            Forwarded positionally to the op unchanged.
        logit_cap: Logit soft-cap, forwarded by keyword.
    """
    decode_op = torch.ops.vllm.rocm_aiter_mla_decode_fwd
    slot_width = q.shape[-1]
    flat_kv = kv_buffer.view(-1, 1, 1, slot_width)
    decode_op(
        q,
        flat_kv,
        o,
        qo_indptr,
        max_seqlen_qo,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        sm_scale=sm_scale,
        logit_cap=logit_cap,
    )


def mla_decode_fwd_impl(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: torch.Tensor | None = None,
    kv_indices: torch.Tensor | None = None,
    kv_last_page_lens: torch.Tensor | None = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    """Real implementation behind the ``rocm_aiter_mla_decode_fwd`` custom op.

    Imports ``aiter.mla.mla_decode_fwd`` lazily and forwards to it. Note the
    positional order passed to AITER: ``max_seqlen_qo`` goes *after* the three
    KV metadata tensors, unlike this function's own parameter order. The
    kernel writes its result into ``o``; nothing is returned.
    """
    # Lazy import: aiter is only expected to be installed on ROCm builds.
    from aiter.mla import mla_decode_fwd as aiter_decode

    flat_kv = kv_buffer.view(-1, 1, 1, q.shape[-1])
    aiter_decode(
        q,
        flat_kv,
        o,
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        max_seqlen_qo,
        sm_scale=sm_scale,
        logit_cap=logit_cap,
    )


def mla_decode_fwd_fake(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: torch.Tensor | None = None,
    kv_indices: torch.Tensor | None = None,
    kv_last_page_lens: torch.Tensor | None = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    pass


if current_platform.is_rocm():
    # Register ``torch.ops.vllm.rocm_aiter_mla_decode_fwd`` (ROCm only).
    # ``o`` is declared mutated because the kernel writes its result in place.
    #
    # ``tags`` must be a flat tuple of ``torch.Tag`` members, which is what
    # ``torch.library.Library.define`` accepts. On torch < 2.7 the op is tagged
    # ``needs_fixed_stride_order``, presumably so the compiler preserves the
    # input stride order the kernel expects (confirm). Newer torch gets no tag.
    if is_torch_equal_or_newer("2.7.0"):
        tags: tuple[torch.Tag, ...] = ()
    else:
        tags = (torch.Tag.needs_fixed_stride_order,)
    direct_register_custom_op(
        op_name="rocm_aiter_mla_decode_fwd",
        op_func=mla_decode_fwd_impl,
        mutates_args=["o"],
        fake_impl=mla_decode_fwd_fake,
        tags=tags,
    )
